Lab 12 - Deep Reinforcement Learning 1

Lab. 12 - Deep Reinforcement Learning cz. 1

Frozen Lake

1. Wstęp

Uczenie ze wzmocnieniem (Reinforcement Learning, RL)

RL polega na trenowaniu agentów, aby podejmowali sekwencje decyzji poprzez interakcję z otoczeniem. Agent stara się maksymalizować skumulowaną nagrodę w czasie. W przeciwieństwie do uczenia nadzorowanego, w RL agent nie otrzymuje par wejście-wyjście, ale musi samodzielnie eksplorować środowisko, aby dowiedzieć się, jakie akcje prowadzą do pożądanych rezultatów.

Stable-Baselines3

To biblioteka open-source implementująca popularne algorytmy uczenia ze wzmocnieniem w Pythonie. Oferuje prosty w użyciu interfejs oraz wsparcie dla najnowszych technik RL, co czyni ją idealnym narzędziem zarówno dla początkujących, jak i zaawansowanych użytkowników. Biblioteka zapewnia takie algorytmy jak PPO, A2C, SAC, DDPG, DQN i inne, wspierając integrację z różnymi środowiskami, w tym Gymnasium.

Deep Q-Learning

Deep Q-learning (DQN) rozszerza Q-learning, używając głębokiej sieci neuronowej do przybliżania funkcji Q. Głęboka sieć neuronowa przyjmuje stan otoczenia jako dane wejściowe i generuje wartości Q dla każdej możliwej akcji. Dzięki temu model może obsługiwać złożone przestrzenie stanów o wysokich wymiarach, co sprawia, że jest odpowiedni do zadań takich jak granie w gry wideo czy sterowanie robotami.

Kluczowe elementy Deep Q-learning

2. Cel zajęć

Celem zajęć jest poznanie podstaw metod głębokiego uczenia ze wzmocnieniem na przykładzie DQN.

3. Przygotowanie środowiska

Na dzisiejszych zajęciach nie będziemy wykorzystywać ROSa. Utwórz wirtualne środowisko Pythona:

python3 -m venv venv
source venv/bin/activate

lub w Visual Studio Code kombinacją klawiszy Ctrl+Shift+P i wpisując Python: Select Interpreter wybierz nowe środowisko wirtualne.

W ramach przygotowania środowiska należy zainstalować bibliotekę PyTorch w odpowiedniej wersji (w zależności od posiadania GPU lub wyłącznie CPU). Instrukcja instalacji jest dostępna tutaj. Nie ma konieczności instalacji torchaudio i torchvision.

Konieczna będzie także instalacja dodatkowych bibliotek:

pip install gymnasium stable-baselines3[extra] numpy wandb

Aktualnie zainstalowaną wersję torcha można sprawdzić:

import torch

print(torch.__version__)

Fakt możliwości używania GPU w PyTorch można sprawdzić:

import torch

print(torch.cuda.is_available())

4. Trening Deep Q-Learning

Należy rozpocząć od zaimportowania potrzebnych bibliotek:

import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor

oraz zainicjowania środowiska:

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 1e3,
    "env_name": "FrozenLake-v1"
    "is_slippery": False,
}

def make_env():
    env = gym.make(config["env_name"], is_slippery=config["is_slippery"])
    env = Monitor(env)  # record stats such as returns
    return env

env = DummyVecEnv([make_env])

W dalszym kroku konieczne jest utworzenie modelu DQN, który następnie zostanie wytrenowany:

# Tworzenie modelu
model = DQN('MlpPolicy', env, verbose=1)

# Trenowanie modelu
model.learn(total_timesteps=config["total_timesteps"])

# Zapisanie modelu
model.save("dqn_frozen_lake")

UWAGA: W przypadku błędów z Qt należy dodać:

import os

os.environ.pop("QT_QPA_PLATFORM_PLUGIN_PATH")

Po zakończeniu treningu można sprawdzić działanie modelu:

# Wczytanie modelu
model = DQN.load("dqn_frozen_lake")

# Ponowne inicjalizowanie środowiska z wizualizacją
env = gym.make('FrozenLake-v1', is_slippery= config["is_slippery"], render_mode='human')
obs, _ = env.reset()

for i in range(1000):
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = env.step(int(action))
    env.render()
    if terminated or truncated:
        break

env.close()

Aktualnie nie wiemy jak działa nasz model. W celu oceny jego skuteczności można wykorzystać funkcję evaluate_policy (najlepiej z wyłączoną wizualizacją):

from stable_baselines3.common.evaluation import evaluate_policy

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=100)
print(f"Mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")

5. Analiza uczenia Deep Q-Learning

Na ten moment nasz model nie działa zbyt dobrze. W celu poprawy jego skuteczności można:

  1. Zwiększyć liczbę kroków treningu.
  2. Zwiększyć rozmiar sieci neuronowej.
  3. Zmienić algorytm uczenia.
  4. Zmienić hiperparametry.

Aby zwiększyć liczbę kroków treningu, można zmienić parametr total_timesteps w funkcji learn. Zwiększ wartość tego parametru do 1e5 (taki zapis oznacza jeden i pięć zer) oraz ponownie wytrenuj model. Po tym zabiegu skuteczność modelu powinna wzrosnąć.

Zmień teraz argument is_slippery na True i ponownie wytrenuj model. Jak ta zmiana wpłynęła na skuteczność modelu?

Nadal jednak trudno jest analizować proces uczenia. W celu lepszej analizy można wykorzystać narzędzia do monitorowania procesu uczenia, takie jak WandB. W celu skorzystania z tego narzędzia należy założyć konto na stronie WandB, a następnie zainicjować projekt. Konieczne jest podanie klucza API (swój klucz możesz znaleźć TUTAJ). Dodaj ten kod na początku skryptu:

import wandb
from wandb.integration.sb3 import WandbCallback

wandb.login(key="YOUR_KEY")
run = wandb.init(project="frozen_lake_arm", config=config, save_code=True, sync_tensorboard=True)

Następnie dodać logowanie do modelu:

model = DQN('MlpPolicy', env, verbose=1, tensorboard_log=f"runs/{run.id}")

Pozostaje tylko dodać odpowiedni callback do uczenia:

model.learn(total_timesteps=1e5, callback=WandbCallback(
    gradient_save_freq=100,
    verbose=2
))

Po uwzględnieniu wszystkich modyfikacji uruchom ponownie uczenie. Podczas uczenia, na stronie WandB pojawi się nowy projekt, w którym można monitorować postępy uczenia. Sprawdź, jakie parametry są tam dostępne i jakie informacje można z nich odczytać. Kiedy możemy stwierdzić, że proces uczenia przebiega pomyślnie? Kiedy możemy uznać trening za zakończony?

6. Zadanie do samodzielnej realizacji

W ramach rozwiązania należy przesłać opis wpływu parametrów (ich znaczenie oraz jak działa algorytm po ich zmianie) oraz kod źródłowy z drugiego zadania (z rozszerzeniem .txt).

Materiały dodatkowe


Autor: Kamil Młodzikowski